{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bf647a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['HF_HOME'] =\"/workspace/cache\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c8d7e4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import safetensors\n",
    "import torch\n",
    "from glob import glob\n",
    "from transformers import Qwen2AudioForConditionalGeneration\n",
    "from safetensors import safe_open\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b04d3592",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.trainer_utils import get_last_checkpoint\n",
    "\n",
    "latest = get_last_checkpoint(\"lora-embedding-128-qwen2audio-7b\")\n",
    "latest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7f970c2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "ori_model = Qwen2AudioForConditionalGeneration.from_pretrained('Qwen/Qwen2-Audio-7B-Instruct')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81404b20",
   "metadata": {},
   "outputs": [],
   "source": [
    "state_dict = ori_model.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7686ba14",
   "metadata": {},
   "outputs": [],
   "source": [
    "files = glob(os.path.join(latest, '*.safetensors'))\n",
    "for f in files:\n",
    "    print(f)\n",
    "    f = safe_open(f, framework=\"pt\", device='cpu')\n",
    "    keys = f.keys()\n",
    "    keys = sorted(list(set([k.split('.lora')[0] for k in keys if '.lora' in k])))\n",
    "\n",
    "    for k in tqdm(keys):\n",
    "        if 'lm_head' in k:\n",
    "            actual_k = 'language_model.lm_head.weight'\n",
    "        else:\n",
    "            actual_k = k.replace('.base_model.model.model.', '.model.') + '.weight'\n",
    "        if 'embed_tokens' in k:\n",
    "            post_A = '.lora_embedding_A.default'\n",
    "            post_B = '.lora_embedding_B.default'\n",
    "        else:\n",
    "            post_A = '.lora_A.default.weight'\n",
    "            post_B = '.lora_B.default.weight'\n",
    "        A = k + post_A\n",
    "        B = k + post_B\n",
    "\n",
    "        W = state_dict[actual_k]\n",
    "        if 'embed_tokens' not in k:\n",
    "            W = W.t()\n",
    "\n",
    "        A = f.get_tensor(A).type(W.dtype)\n",
    "        B = f.get_tensor(B).type(W.dtype)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            W.addmm_(A.t(), B.t(), alpha = 2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
